# -*- coding: utf-8 -*-
"""SR-Unified.ipynb

Automatically generated by Colaboratory.

"""

# Notebook shell magics (IPython `!` syntax, not plain Python): install the
# third-party packages imported below and log in to Weights & Biases.
!pip install functorch
!pip install munch
!pip install wandb
!wandb login

# Notebook-level switches read by the inline self-checks further down.
wandb_agent_mode = False
run_tests = True

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import time
import copy
torch.set_default_dtype(torch.float64)
from munch import Munch

"""# Wandb and parameterization"""

import wandb

# Default hyperparameters of a run. They are wrapped in a Munch so that they
# can be read as attributes (args.NAME) in the rest of the file.
args = {
  "RANDOM_SEED" : 1,
  "INP_DIM" : 10,
  "OUTP_DIM" : 2,
  "N_OF_INPS" : 100, #number of different inputs used both for computing diff and for
                #calculating expected utilities
  "G" : 5, #weight of the opponent's distance to fC in each player's loss (see loss())
  "MANUAL_fCD" : True, #True: fC/fD are built with random_sum_sin; False: with random NeuralNets
  "fCD_HIDDEN_DIM_LIST" : [4],
  "fCD_HIDDEN_TYPE_LIST" : ["LeakyReLU", "LeakyReLU"],
  "MODEL_HIDDEN_DIM_LIST" : [100, 50, 50],
  "MODEL_HIDDEN_TYPE_LIST" : ["LeakyReLU","LeakyReLU","LeakyReLU","LeakyReLU"],
  "NOISE_TYPE" : "normal", #"uniform" or "normal" (see generate_noise())
  "NOISE_SIZE" : 0.05,
  "TEST_DIFFS_RANGE" : 0.1,
  "STEP_2_SELF_PLAY_PROB" : 0.5,
  "STEP_2_LR" : 0.02,
  "STEP_2_OPTIMIZER" : "Adam",
  "N_OF_STEPS_STEP_2" : 200,
  "NO_NOISE_STEP_2" : True,
  "LOLA_LR" : 0.0001,
  "LOLA_LA" : 0.001,
  "N_OF_LOLA_STEPS" : 0,
  "LOLA_EARLY_TERMINATION" : True,
  "LOLA_EARLY_TERMINATION_SHIELD" : 20000,
  "LOLA_LINEAR_LA_DECAY" : True,
  "LOG_BEHAVIOR_GRAPH_EVERY_N_STEPS" : 1000,
  "TAYLOR_LOLA" : False,
  "N_OF_STEPS_PER_TURN_STEP_3" : 3000,
  "STEP_3_LR" : 0.00005,
  "N_OF_TURNS_STEP_3" : 1000,
  "STEP_3_OPTIMIZER": "SGD", #Adam gives weird results here. So all experiments are now using SGD.
  "STEP_3_IMPROVEMENTS_ONLY": True,
  "STEP_3_LR_EXPONENT": 0  #between 0 (constant) and -1 (1/t), # all experiments currently set this to 0.
}
args = Munch.fromDict(args)

#Track parameters in wandb:
# wandb_run is kept as a global because later code appends a "Fail" tag to it.
wandb_run = wandb.init(project="SR-unified", entity="",\
                        config = args, group= "v0")

#For sweeps:
# From here on `args` is the wandb config object rather than the Munch above;
# presumably a sweep agent can override the defaults through it -- verify
# against the wandb documentation.
args = wandb.config

# Seed torch's global random number generator.
torch.manual_seed(args.RANDOM_SEED)

"""# Neural Nets"""

class NeuralNet(nn.Module):
  """Fully connected network: one Linear + activation per hidden layer, then a final Linear.

  Args:
    inp_dim: number of input features.
    hidden_dim_list: width of each hidden layer (any sequence of ints).
    hidden_type_list: activation name for each hidden layer, either
      "LeakyReLU" or "LogSigmoid". It must have at least as many entries as
      hidden_dim_list; surplus entries are ignored.
    outp_dim: number of output features.

  Raises:
    ValueError: if there are fewer activation names than hidden layers, or
      if a used activation name is unknown.
  """

  # Supported activation names and the module class each one maps to.
  _ACTIVATIONS = {"LeakyReLU": nn.LeakyReLU, "LogSigmoid": nn.LogSigmoid}

  def __init__(self, inp_dim, hidden_dim_list, hidden_type_list, outp_dim):
    super().__init__()
    self.inp_dim = inp_dim
    self.outp_dim = outp_dim
    # Copy into lists so that tuples (or other sequences) are accepted too.
    hidden_dims = list(hidden_dim_list)
    hidden_types = list(hidden_type_list)
    if len(hidden_types) < len(hidden_dims):
      raise ValueError(
          f"need an activation for each of the {len(hidden_dims)} hidden layers, "
          f"got only {len(hidden_types)}")
    full_dims_list = [inp_dim] + hidden_dims
    self.layers = nn.ModuleList()
    for i in range(len(hidden_dims)):
      activation_name = hidden_types[i]
      if activation_name not in self._ACTIVATIONS:
        raise ValueError(f"unknown hidden layer type: {activation_name!r}")
      self.layers.append(nn.Linear(full_dims_list[i], full_dims_list[i+1]))
      self.layers.append(self._ACTIVATIONS[activation_name]())
    self.layers.append(nn.Linear(full_dims_list[-1], outp_dim))

  def forward(self, x):
    """Apply all layers in order to x and return the result."""
    result = x
    for layer in self.layers:
      result = layer(result)
    return result

"""# Implementation of LOLA """

import torch
import copy
from functorch import make_functional, grad_and_value
from functools import partial

def __get_gradient(value, params):
  """Return the gradients of `value` w.r.t. `params`.

  The autograd graph is kept (create_graph=True) so that the returned
  gradients can themselves be differentiated again.
  """
  return torch.autograd.grad(value, params, create_graph=True)

def lola_update(loss_fn, models, alpha, beta, algo='exact_lola'):
  """Perform one LOLA update on both players' models, in place.

  Args:
    loss_fn: takes a list of 2 callables (one per player) and returns the
      loss values for each player. The callables are partial applications of
      functional models, so loss_fn CANNOT call model.forward(); it must
      call model(x) instead.
    models: list of the two players' models.
    alpha: learning rate.
    beta: opponent shaping / lookahead rate.
    algo: 'exact_lola' or 'taylor_lola'.

  Returns:
    The list of both players' losses, evaluated before the update.

  Raises:
    Exception: if algo is neither 'exact_lola' nor 'taylor_lola'.
  """

  assert len(models) == 2, 'more than 2 players not implemented yet'

  for model in models:
    model.zero_grad()

  n = 2

  #start by making models functional
  funcs_params = [make_functional(model) for model in models]
  funcs = [func_model[0] for func_model in funcs_params]
  params = [func_model[1] for func_model in funcs_params]

  assert len(params[0]) == len(params[1]), 'different number of params for different players not implemented yet'

  n_params = len(params[0])

  # Evaluate loss_fn for an explicit choice of both players' parameters.
  def _Ls(params):
    _models = [partial(funcs[i],params[i]) for i in range(n)]
    return loss_fn(_models)

  losses = _Ls(params)

  if algo == 'exact_lola':
    # gradient of each player's own loss w.r.t. its own parameters
    opponent_steps = [__get_gradient(losses[i], params[i]) for i in range(n)]

    # update opponent's parameters with naive learning
    updated_params = [tuple(params[i][j] - beta*opponent_steps[i][j] for j in range(n_params)) for i in range(n)]


    # calculate loss for players again, using updated params
    # (each player keeps its current parameters and faces the opponent's
    # looked-ahead parameters)
    lookahead_losses = [_Ls([params[0],updated_params[1]])[0],
                        _Ls([updated_params[0],params[1]])[1]]

    # differentiate through the opponent's lookahead step
    grads = [__get_gradient(lookahead_losses[i], params[i]) for i in range(n)]

  elif algo == 'taylor_lola':
    # grad_L[i][j] is the gradient of player j's loss w.r.t. player i's parameters
    grad_L = [[__get_gradient(losses[j], params[i]) for j in range(n)] for i in range(n)]

    # terms[i]: dot product, over the opponent's parameters, of the gradient
    # of player i's loss and the gradient of the opponent's own loss
    terms = [sum([torch.dot(grad_L[j][i][v].flatten(), grad_L[j][j][v].flatten())
                  for j in range(n) if j != i for v in range(n_params)]) for i in range(n)]

    second_order_grads = [__get_gradient(terms[i], params[i]) for i in range(n)]
    grads = [tuple(grad_L[i][i][j]-beta*second_order_grads[i][j] for j in range(n_params)) for i in range(n)]

  else:
    raise Exception(f'invalid algo specification: {algo}')

  # update all players' params for real
  # (relies on models[i].parameters() yielding parameters in the same order
  # as the tuple returned by make_functional -- NOTE(review): confirm)
  with torch.no_grad():
    for i in range(n):
        for j, param in enumerate(models[i].parameters()):
          param.sub_(alpha*grads[i][j])

  return losses

"""# Definition of the diff game"""

from random import randint

#This is used to randomly generate functions f_C and f_D.
def random_sum_sin(inps):
  """Return sin of the sum of a random subset of each input row's coordinates.

  A random 0/1 selector over the INP_DIM coordinates is drawn on every call.
  It comes from torch's random number generator, so that
  torch.manual_seed(args.RANDOM_SEED) makes the generated function
  reproducible (Python's `random` module is not seeded by that call).

  Args:
    inps: tensor of shape (args.N_OF_INPS, args.INP_DIM).

  Returns:
    Tensor of shape (args.N_OF_INPS, 1).
  """
  assert inps.shape == (args.N_OF_INPS, args.INP_DIM)
  # high=2 is exclusive, so every entry is 0 or 1
  selector = torch.randint(0, 2, (args.INP_DIM,)).to(inps.dtype)
  sums = inps @ selector
  return torch.reshape(torch.sin(sums), (args.N_OF_INPS, 1))

def generate_noise():
  """Draw a fresh (args.N_OF_INPS, 1) noise tensor according to args.NOISE_TYPE.

  "uniform" draws from [0, args.NOISE_SIZE) (not centred around zero);
  "normal" draws from a zero-mean normal with standard deviation
  args.NOISE_SIZE.

  Raises:
    ValueError: if args.NOISE_TYPE is neither "uniform" nor "normal".
  """
  if args.NOISE_TYPE == "uniform":
    return args.NOISE_SIZE * torch.rand(args.N_OF_INPS, 1)
  if args.NOISE_TYPE == "normal":
    return torch.normal(mean=0.0,std=args.NOISE_SIZE, size=(args.N_OF_INPS, 1))
  # An explicit error (instead of `assert False`) also fires when Python runs with -O.
  raise ValueError(f"unknown NOISE_TYPE: {args.NOISE_TYPE!r}")

def distance(t1, t2):
  """Return the mean, over rows, of the Euclidean distance between t1 and t2.

  Both tensors must have args.OUTP_DIM columns.
  """
  for tensor in (t1, t2):
    assert tensor.shape[1] == args.OUTP_DIM
  row_norms = torch.linalg.norm(t1-t2, ord=2, dim=1)
  return torch.mean(row_norms)

#Input
# inps: the fixed batch of task inputs shared by every loss/diff evaluation.
inps = torch.rand(args.N_OF_INPS, args.INP_DIM)
# test_diffs: the values fed into the models' extra "diff" input slot when
# diff() compares two models.
test_diffs = args.TEST_DIFFS_RANGE * torch.rand(args.N_OF_INPS, 1)

# Noise vectors: one for the test diffs and one per player for the observed
# diff in loss(). All are zero if step 2 is to be run without noise.
if args.NO_NOISE_STEP_2:
  test_diff_noises = torch.zeros(args.N_OF_INPS, 1)
  diff_noise_vals_pl1 = torch.zeros(args.N_OF_INPS, 1)
  diff_noise_vals_pl2 = torch.zeros(args.N_OF_INPS, 1)
else:
  test_diff_noises, diff_noise_vals_pl1, diff_noise_vals_pl2 = generate_noise(), generate_noise(), generate_noise()

# test_inps has the (noisy) test diff as column 0, followed by the task
# inputs, i.e. args.INP_DIM+1 columns.
test_diffs += test_diff_noises
test_inps = torch.cat((test_diffs, inps), dim=1)


# Build the target outputs fD_vals and fC_vals (one row per input in inps)
# and the normalising constant CD_diff, the distance between them.
if args.MANUAL_fCD:
  #instead of generating a random neural net to generate fD and fC,
  #we generate manual fC and fD functions. This is to make sure that these
  #functions are sensible and not just ~constant, for example.
  fD_vals = torch.cat([random_sum_sin(inps) for _ in range(args.OUTP_DIM)], dim=1)
  fC_vals = torch.cat([random_sum_sin(inps) for _ in range(args.OUTP_DIM)], dim=1)
  print("fD_vals",fD_vals)
else:
  # The targets are constants. Computing them under no_grad keeps them from
  # carrying the autograd graph of the throw-away generator networks, which
  # every later loss.backward() would otherwise have to traverse.
  with torch.no_grad():
    fD_model = NeuralNet(args.INP_DIM, args.fCD_HIDDEN_DIM_LIST, args.fCD_HIDDEN_TYPE_LIST, args.OUTP_DIM)
    fD_vals = fD_model(inps)
    print(fD_vals)
    del fD_model
    fC_model = NeuralNet(args.INP_DIM, args.fCD_HIDDEN_DIM_LIST, args.fCD_HIDDEN_TYPE_LIST, args.OUTP_DIM)
    fC_vals = fC_model(inps)
    print(fC_vals)
    del fC_model

CD_diff = distance(fC_vals,fD_vals).item()

def diff(model1, model2):
  """Return the distance between the two models' outputs on test_inps, divided by CD_diff.

  Both models must take args.INP_DIM+1 inputs.
  """
  expected_inp_dim = args.INP_DIM+1
  assert model1.inp_dim == model2.inp_dim == expected_inp_dim
  outputs = [m.forward(test_inps) for m in (model1, model2)]
  return distance(outputs[0], outputs[1])/CD_diff

#BEGIN TEST
#Tests whether the distance of a model to itself is small.
if run_tests:
  # A freshly initialised model compared with itself must give a diff of
  # (numerically) zero.
  model = NeuralNet(args.INP_DIM+1,args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST,args.OUTP_DIM)
  assert -1e-10< diff(model,model).item() < 1e-10
  del model
#END Test

def loss(models, false_diff = None):
  """Return [loss1, loss2], the two players' losses in the diff game.

  Each model receives the task inputs `inps` preceded by one extra column:
  the diff between the two models plus that player's noise vector. A
  player's loss is its own distance to fD_vals plus args.G times the
  opponent's distance to fC_vals, divided by CD_diff.

  Args:
    models: list [model1, model2] of NeuralNet instances.
    false_diff: if given, this number is fed to the models in place of the
      true diff. The true diff is then not computed at all.

  Returns:
    A list of two tensors of shape (1,).
  """
  model1, model2 = models[0], models[1]
  if false_diff is None:
    d = diff(model1,model2)
  else:
    d = torch.Tensor([false_diff]).repeat(args.N_OF_INPS,1)
  # d is either a scalar or an (N_OF_INPS, 1) column; adding the per-player
  # noise broadcasts it to one observed diff per input row.
  d1 = d + diff_noise_vals_pl1
  d2 = d + diff_noise_vals_pl2
  inps1 = torch.cat((d1,inps), dim=1)
  inps2 = torch.cat((d2,inps), dim=1)
  y1 = model1.forward(inps1)
  y2 = model2.forward(inps2)
  loss1 = torch.reshape((distance(y1,fD_vals) + args.G * distance(y2,fC_vals))/CD_diff,(1,))
  loss2 = torch.reshape((distance(y2,fD_vals) + args.G * distance(y1,fC_vals))/CD_diff,(1,))
  return [loss1,loss2]

#BEGIN TEST
if run_tests:
  # With the same model in both seats the two players' losses should
  # (nearly) coincide; presumably the tolerance leaves room for the two
  # players' different noise vectors.
  model = NeuralNet(args.INP_DIM+1,args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST,args.OUTP_DIM)
  assert -0.005< loss([model,model])[0].item() - loss([model,model])[1].item() < 0.005
  del model
#END Test

# Now we implement a copy of the original loss function, except that it doesn't
# use model.forward(inps). Instead, it just uses model(inps).
# This is necessary for our implementations of exact and Taylor LOLA.
# We still need the original function for everything else. 

def diff_without_forward(model1, model2):
  """Variant of diff() that calls the models directly (model(x)) instead of model.forward(x).

  Because of that it also works for plain callables; no input-dimension
  check is performed.
  """
  out1, out2 = model1(test_inps), model2(test_inps)
  return distance(out1, out2)/CD_diff

#BEGIN TEST
#Tests whether diff_without_forward of a model to itself is small.
if run_tests:
  model = NeuralNet(args.INP_DIM+1,args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST,args.OUTP_DIM)
  # Exercise diff_without_forward here; diff() has its own check further up.
  assert -1e-10< diff_without_forward(model,model).item() < 1e-10
  del model
#END Test

def loss_without_forward(models, false_diff = None):
  """Variant of loss() that calls the models directly (model(x)) instead of model.forward(x).

  This makes it usable with the functional models that lola_update passes
  in. Each player's loss is its own distance to fD_vals plus args.G times
  the opponent's distance to fC_vals, divided by CD_diff.

  Args:
    models: list of two callables, one per player.
    false_diff: if given, this number is fed to the models in place of the
      true diff. The true diff is then not computed at all.

  Returns:
    A list [loss1, loss2] of two tensors of shape (1,).
  """
  model1, model2 = models[0], models[1]
  if false_diff is None:
    d = diff_without_forward(model1,model2)
  else:
    d = torch.Tensor([false_diff]).repeat(args.N_OF_INPS,1)
  # d is either a scalar or an (N_OF_INPS, 1) column; adding the per-player
  # noise broadcasts it to one observed diff per input row.
  d1 = d + diff_noise_vals_pl1
  d2 = d + diff_noise_vals_pl2
  inps1 = torch.cat((d1,inps), dim=1)
  inps2 = torch.cat((d2,inps), dim=1)
  y1 = model1(inps1)
  y2 = model2(inps2)
  loss1 = torch.reshape((distance(y1,fD_vals) + args.G * distance(y2,fC_vals))/CD_diff,(1,))
  loss2 = torch.reshape((distance(y2,fD_vals) + args.G * distance(y1,fC_vals))/CD_diff,(1,))
  return [loss1,loss2]

#BEGIN TEST
# Gated on run_tests like every other inline test in this file.
if run_tests:
  # With the same model in both seats the two players' losses should
  # (nearly) coincide.
  model = NeuralNet(args.INP_DIM+1,args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST,args.OUTP_DIM)
  assert -0.005< loss_without_forward([model,model])[0].item() - loss_without_forward([model,model])[1].item() < 0.005
  del model
#END Test

"""
# Analysis and logging functions"""

import copy as copy

#This method tests whether a given model is a best response to another given model.
#It does so by perturbing the given model and seeing whether this yields lower loss.
def best_response_test(model, opponent_model, n_of_perturbations=10000, epsilon=0.0001):
  """Estimate how close `model` is to a local best response against `opponent_model`.

  The model is copied n_of_perturbations times; every entry of the copy's
  state dict is shifted by independent uniform noise from
  [-epsilon/2, epsilon/2), and the copy's loss against the opponent is
  compared with the unperturbed loss.

  Args:
    model: the model being tested (it is not modified).
    opponent_model: the fixed opponent.
    n_of_perturbations: number of random perturbations to try; must be positive.
    epsilon: width of the uniform perturbation interval.

  Returns:
    The fraction of perturbations that did NOT lower the model's loss
    (1.0 means no tried perturbation was an improvement).

  Raises:
    ValueError: if n_of_perturbations is not positive.
  """
  if n_of_perturbations <= 0:
    raise ValueError("n_of_perturbations must be positive")
  n_of_improvements = 0
  # Only loss values are compared, so no autograd graph is needed.
  with torch.no_grad():
    old_loss = loss([model, opponent_model])[0].item()
    for _ in range(n_of_perturbations):
      alt_model = copy.deepcopy(model)
      # state_dict tensors share storage with the copy's parameters, so the
      # in-place copy_ perturbs alt_model itself.
      for param in alt_model.state_dict().values():
        param.copy_(param + epsilon*torch.rand(param.shape) - epsilon/2)
      if loss([alt_model, opponent_model])[0].item() < old_loss:
        n_of_improvements += 1
  return 1-(n_of_improvements / n_of_perturbations)

#This asks: if artificially both agents received a slightly higher diff input,
# how would this affect their utility. If the models successfully set incentives
# on each other, this should be positive.
# (I currently don't track this anymore, because it wasn't super insightful in
#  the past and because it broke for some technical reason.)
def loss_diff_rate(model, opponent, epsilon = 0.001):
  """Finite-difference rate of change of `model`'s loss w.r.t. the diff fed to both agents.

  The loss of `model` (player index 0 in loss([model, opponent])) is
  evaluated once with the true diff and once with the diff artificially
  raised by epsilon. Both evaluations use the same player's loss, so the
  quotient measures the effect of the diff input only.

  Args:
    model: the model whose loss is measured.
    opponent: the other player's model.
    epsilon: size of the artificial increase of the diff.

  Returns:
    (loss with increased diff - true loss) / epsilon, as a Python float.
  """
  # Only values are needed, so skip building autograd graphs.
  with torch.no_grad():
    true_loss = loss([model, opponent])[0].item()
    increased_diff = diff(model,opponent).item()+epsilon
    incr_loss = loss([model, opponent], increased_diff)[0].item()
  return (incr_loss-true_loss)/epsilon

def model_behavior_stats(model, n_of_points = 300):
  """Measure the model's distance to fC_vals and fD_vals for a grid of diff inputs.

  The grid consists of the values i/n_of_points for i = 0 .. n_of_points-1.
  For each grid value the same value is fed as the diff input for every row
  of `inps`.

  Returns:
    (grid, dists_to_fC, dists_to_fD): the grid as a list of 1-element
    tensors, and two lists of floats with one entry per grid value.
  """
  grid = [torch.tensor([k/n_of_points]) for k in range(n_of_points)]
  dists_to_fC, dists_to_fD = [], []
  for grid_value in grid:
    diff_column = torch.reshape(grid_value.repeat(args.N_OF_INPS),(args.N_OF_INPS,1))
    outputs = model.forward(torch.cat((diff_column,inps), dim=1))
    dists_to_fC.append(distance(outputs,fC_vals).item())
    dists_to_fD.append(distance(outputs,fD_vals).item())
  return grid, dists_to_fC, dists_to_fD

# Deprecated -- we now mostly use wandb for this
def behavior_graph(model, n_of_points = 300):
  """Plot, with matplotlib, the model's distance to fD and fC as a function of the diff input.

  Also drawn are the sum of both distances and the constant CD_diff.
  """
  xs, to_fC, to_fD = model_behavior_stats(model, n_of_points)
  curves = [
      (to_fD, "dist to D"),
      (to_fC, "dist to C"),
      ([d_val + c_val for d_val, c_val in zip(to_fD, to_fC)], "dist to C + dist to D"),
      ([CD_diff]*n_of_points, "dist C to D"),
  ]
  for ys, curve_label in curves:
    plt.plot(xs, ys, label=curve_label)
  plt.legend()
  plt.xlabel("ag_diff")
  plt.ylabel("dist to C/D")
  plt.ylim(bottom=0)
  plt.show()

def log_behavior_graph_to_wandb(model, id, n_of_points=300):
  """Log the model's behavior curves to wandb as a line-series chart.

  `id` is used both as the key under which the chart is logged and as the
  chart title. The curves are the same four as in behavior_graph().
  """
  grid, to_fC, to_fD = model_behavior_stats(model, n_of_points)
  x_values = [grid_value.item() for grid_value in grid]
  summed = [y+z for (y, z) in zip(to_fD, to_fC)]
  baseline = [CD_diff]*len(grid)
  chart = wandb.plot.line_series(
      xs=x_values,
      ys=[to_fD, to_fC, summed, baseline],
      keys=["dist to D", "dist to C", "dist to C + dist to D", "dist C to D"],
      title=id,
      xname="ag_diff")
  wandb.log({id : chart})
  
def behavior_change_graph(old_model, new_model, diff_val=None, n_of_points = 300):
  """Plot how the distances to fD and fC changed from old_model to new_model.

  For every diff input on the grid of model_behavior_stats(), the new
  model's distance minus the old model's distance is drawn. If diff_val is
  given, a vertical line marks it.
  """
  _, old_to_fC, old_to_fD = model_behavior_stats(old_model, n_of_points)
  xs, new_to_fC, new_to_fD = model_behavior_stats(new_model, n_of_points)
  change_in_fD = [new - old for new, old in zip(new_to_fD, old_to_fD)]
  change_in_fC = [new - old for new, old in zip(new_to_fC, old_to_fC)]
  plt.plot(xs, change_in_fD, label="increase in dist to D")
  plt.plot(xs, change_in_fC, label="increase in dist to C")
  plt.axhline(y=0)
  if diff_val is not None:
    plt.axvline(x=diff_val, label="current diff between models[0] and models[1]")
  plt.legend(bbox_to_anchor=(1.1, 1.05))
  plt.xlabel("ag_diff")
  plt.show()


def print_diffs_to_random(model):
    """Print the sorted diffs between `model` and 30 freshly initialised random models.

    Each diff is cut off (not rounded) after 4 decimal places to keep the
    printout compact.
    """
    import math  # local import: this function is defined before the file-level `import math`

    def _truncate(value, digits):
        # Cut `value` off after `digits` decimals; non-finite values pass through.
        if math.isnan(value) or math.isinf(value):
            return value
        factor = 10 ** digits
        return math.trunc(value * factor) / factor

    opponents_lst = [NeuralNet(args.INP_DIM+1,args.MODEL_HIDDEN_DIM_LIST,args.MODEL_HIDDEN_TYPE_LIST,args.OUTP_DIM)\
                        for i in range(30)]
    trunc_diffs = sorted(_truncate(diff(model, opponent).item(), 4) for opponent in opponents_lst)
    print("diffs to random:")
    print(trunc_diffs)

"""# Step 2 pre-training"""

def step2_loss(model, opponents_lst, self_play_prob=1/2):
  """Return the step 2 training loss of `model`.

  The loss is self_play_prob times the model's loss when playing against
  itself, plus (1 - self_play_prob) times its average loss against the
  given opponents.

  Args:
    model: the model being trained.
    opponents_lst: list of opponent models; may be None or empty, in which
      case only the self-play term is used.
    self_play_prob: weight of the self-play term.

  Returns:
    A tensor of shape (1,).
  """
  loss_val = torch.zeros(1)
  loss_val += self_play_prob * loss([model,model])[0]
  # Pure self-play callers pass no opponents; there is nothing to average then.
  if opponents_lst:
    opponent_weight = (1-self_play_prob) * (1/len(opponents_lst))
    for opponent in opponents_lst:
      loss_val += opponent_weight * loss([model,opponent])[0]
  return loss_val

def step2(model, n_of_steps=100, self_play_prob=1/2, n_of_opponents_per_epoch=100,
            print_progress=True, wandb_track = True):
  """Pre-train `model` in place by minimising step2_loss.

  In every step a fresh set of randomly initialised opponents is created
  (none if self_play_prob >= 1), and one optimizer step is taken on the
  mixture of self-play loss and loss against these opponents.

  Args:
    model: the model to train.
    n_of_steps: number of optimizer steps.
    self_play_prob: weight of the self-play term in the loss.
    n_of_opponents_per_epoch: number of random opponents created per step.
    print_progress: whether to print the loss after every step.
    wandb_track: whether to log metrics to wandb after every step.

  Raises:
    ValueError: if args.STEP_2_OPTIMIZER is neither "Adam" nor "SGD".
  """
  if args.STEP_2_OPTIMIZER == "Adam":
    optimizer = optim.Adam(model.parameters(), lr=args.STEP_2_LR)
  elif args.STEP_2_OPTIMIZER == "SGD":
    optimizer = optim.SGD(model.parameters(), lr=args.STEP_2_LR)
  else:
    raise ValueError(f"unknown STEP_2_OPTIMIZER: {args.STEP_2_OPTIMIZER!r}")
  for step_no in range(n_of_steps):
    optimizer.zero_grad()
    # An empty list (rather than None) keeps step2_loss's loop over the
    # opponents valid for pure self-play.
    opponents_lst = []
    if self_play_prob<1:
      opponents_lst = [NeuralNet(args.INP_DIM+1,args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST,args.OUTP_DIM)\
                        for i in range(n_of_opponents_per_epoch)]
    loss_val = step2_loss(model, opponents_lst, self_play_prob)
    loss_val.backward(retain_graph = True)

    #for logging, calculate the avg of the entries of the gradient
    #Note that this is actually the average of the per-tensor averages, not
    #the average over all individual entries.
    per_param_means = [torch.mean(torch.abs(param.grad)).item() for param in model.parameters()]
    gradient_avg = sum(per_param_means)/len(per_param_means)

    optimizer.step()

    if wandb_track:
      metrics = {
        'step_2_loss': loss_val[0].item(),
        'phase': "Step 2",
        'loss of (D,D)': args.G,
        'loss of (C,C)': 1,
        'gradient avg step 2': gradient_avg
      }
      wandb.log(metrics)

    if print_progress:
      print('Step 2 Epoch {}, Loss {}'.format(step_no, loss_val.item()))
    

### BEGIN TEST (of Step 2)
if run_tests:
  #These tests are "dirty" in the sense that they don't work for
  #arbitrary values of args.

  # Test 1: training with self-play only. Afterwards the self-play loss is
  # expected to lie in [1, 1.1).
  model = NeuralNet(args.INP_DIM+1,args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST,args.OUTP_DIM)
  step2(model, n_of_steps=5000, self_play_prob=1, print_progress=False, wandb_track=False)
  #behavior_graph(model)
  #print_diffs_to_random(model)
  #print("loss:", loss(model,model))
  #print("CD_diff:", CD_diff)
  assert 1<=loss([model,model])[0].item() < 1.1
  del model

  # Test 2: training against random opponents only. When the model is then
  # fed a forced diff of 0.7, its self-play loss is expected to be within
  # 0.3 of args.G.
  model = NeuralNet(args.INP_DIM+1,args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST,args.OUTP_DIM)
  step2(model, n_of_steps=5000, n_of_opponents_per_epoch=10, self_play_prob=0, print_progress=False, wandb_track=False)
  #behavior_graph(model)
  #print_diffs_to_random(model)
  print("loss:", loss([model,model]))
  print("CD_diff:", CD_diff)
  assert args.G-0.3<=loss([model,model],false_diff=0.7)[0].item() < args.G+0.3
  del model
### END TEST


"""# Step 3"""

import copy as copy
import math as math

def step3(model1, model2=None, n_of_turns=100, print_progress = False,\
          unilateral = False, self_play=False, n_of_steps_per_turn=1,
          print_turn_numbers=True, optimizer_type=args.STEP_3_OPTIMIZER,
          early_termination_at_mutual_D=True):
  """Alternating gradient-based best-response learning; the models are updated in place.

  In every turn each updating agent gets a fresh optimizer and takes
  n_of_steps_per_turn gradient steps on its own loss while the other model
  is held fixed. The step size of each step is args.STEP_3_LR, scaled by
  (1 + global step index) ** args.STEP_3_LR_EXPONENT and by a fresh uniform
  random factor from [0, 1). If args.STEP_3_IMPROVEMENTS_ONLY is set, a step
  that does not lower the agent's own loss (or yields a NaN loss) is
  reverted. After every turn the losses are logged to wandb.

  Args:
    model1: model of agent 0.
    model2: model of agent 1. May be None only if self_play is True.
    n_of_turns: number of turns.
    print_progress: whether to print per-step information and show the
      behavior plots after each agent's steps.
    unilateral: if True, only agent 0 is updated.
    self_play: if True, agent 1 is replaced by a copy of agent 0 at the
      start of every turn (implies unilateral).
    n_of_steps_per_turn: gradient steps per agent and turn.
    print_turn_numbers: whether to print a line at the start of every turn.
    optimizer_type: "Adam" or "SGD".
    early_termination_at_mutual_D: if True, stop (and tag the wandb run
      with "Fail") as soon as a loss reaches args.G or becomes NaN.

  Returns:
    None.

  Raises:
    ValueError: if optimizer_type is unknown, or if model2 is None while
      self_play is False.
  """
  optimizer_classes = {"Adam": optim.Adam, "SGD": optim.SGD}
  if optimizer_type not in optimizer_classes:
    raise ValueError(f"unknown optimizer_type: {optimizer_type!r}")

  models = [model1, model2]
  if self_play:
    unilateral = True
  if models[1] is None:
    if not self_play:
      raise ValueError("model2 is required unless self_play=True")
    # The opponent must exist before the first loss evaluation below.
    models[1] = copy.deepcopy(models[0])
  agents_to_update = [0] if unilateral else [0, 1]
  optimizers = [None] * len(agents_to_update)

  #variables used for analysis
  n_of_successful_steps = [0] * len(agents_to_update)

  pre_turn_losses = loss(models)
  pre_turn_losses = (pre_turn_losses[0].item(), pre_turn_losses[1].item())

  for turn_no in range(n_of_turns):
    if print_turn_numbers:
      print ("Turn ",turn_no+1," of ",n_of_turns, "Losses: ", pre_turn_losses[0],", ",pre_turn_losses[1])
    if self_play:
      models[1] = copy.deepcopy(models[0])
    for i in agents_to_update:
      # A fresh optimizer per agent and turn (so e.g. Adam state is reset).
      optimizers[i] = optimizer_classes[optimizer_type](models[i].parameters(), lr=args.STEP_3_LR)

      # Snapshot of the agent before its most recent step; stays None if no
      # step is taken this turn.
      old_model = None
      for step_no in range(n_of_steps_per_turn):
        #reduce learning rate, which according to
        #https://stackoverflow.com/questions/48324152/pytorch-how-to-change-the-learning-rate-of-an-optimizer-at-any-given-moment-no
        #can be done as follows.
        step_size = ((1+turn_no*n_of_steps_per_turn+step_no) ** args.STEP_3_LR_EXPONENT) * args.STEP_3_LR * torch.rand(1).item()
        for g in optimizers[i].param_groups:
          g['lr'] = step_size
        old_model = copy.deepcopy(models[i])

        models[i].zero_grad()

        loss_vals = loss(models)
        loss_vals[i].backward(retain_graph = True)
        optimizers[i].step()

        new_losses = loss(models)
        pre_step = (loss_vals[0].item(), loss_vals[1].item())
        post_step = (new_losses[0].item(), new_losses[1].item())

        #If they defect, give up (to save compute)
        # NaN is checked on the post-step losses as well, so that a step
        # that breaks the model is detected immediately.
        has_nan = any(math.isnan(v) for v in pre_step + post_step)
        if early_termination_at_mutual_D and (post_step[0] >= args.G\
                                              or post_step[1] >= args.G\
                                              or has_nan):
          wandb_run.tags = wandb_run.tags + ("Fail",)
          return

        #If gradient step increased loss, then revert:
        # (a NaN loss is never counted as an improvement)
        loss_decrease = pre_step[i] - post_step[i]
        if args.STEP_3_IMPROVEMENTS_ONLY and (loss_decrease < 0.0 or math.isnan(loss_decrease)):
          models[i].load_state_dict(old_model.state_dict())
        else:
          n_of_successful_steps[i] += 1
          if print_progress:
            print("Agent ",i," just made an update.")
            print("diff between agents: ",diff(models[0],models[1]))
            print("Current rates of increasing diff:", loss_diff_rate(models[0], models[1]))
            print('Step 3 Ag {} Turn {}, Losses [0, 1]: {}, {}'.format(i, turn_no,\
                                                              post_step[0],\
                                                              post_step[1]))

        # No per-step wandb logging is done here; metrics are logged once
        # per turn below.

      if print_progress:
          behavior_graph(models[i])
          if old_model is not None:
            behavior_change_graph(old_model=old_model, new_model=models[i],\
                                  diff_val=diff(models[0],models[1]).item())

    #log change in loss throughout step
    new_losses = loss(models)
    new_losses = (new_losses[0].item(), new_losses[1].item())
    wandb.log({'loss_0': new_losses[0],
               'loss_1': new_losses[1],
               'turn_no': turn_no,
               'phase': "Step 3",
               'loss of (D,D)': args.G,
               'loss of (C,C)': 1,
               'agent_diffs': diff(models[0], models[1]).item(),
               'turn_loss_decrease_0' : pre_turn_losses[0]-new_losses[0],
               'turn_loss_decrease_1' : pre_turn_losses[1]-new_losses[1],
               'loss_diff': abs(new_losses[0]-new_losses[1])})
    pre_turn_losses = new_losses

  print("Number of successful steps:", n_of_successful_steps)

"""# Running the experiments

## Initialization
"""

# Create one freshly initialised model per agent. Each model takes the diff
# as an extra input, hence INP_DIM+1.
list_of_agents_to_create = [0, 1]
models = []
for i in list_of_agents_to_create:
  models.append(NeuralNet(args.INP_DIM+1,args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST, args.OUTP_DIM))

"""## Pre-train ("Step 2")"""

# Pre-train each agent independently.
for i in list_of_agents_to_create:
  step2(models[i], n_of_steps=args.N_OF_STEPS_STEP_2, self_play_prob=args.STEP_2_SELF_PLAY_PROB)
  #behavior_graph(models[i])
  #print_diffs_to_random(models[i])
if len(list_of_agents_to_create)==2:
  print("diff after Step 2: ",diff(models[0],models[1]))

# NOTE(review): the lines below index models[0] and models[1] directly, so
# they assume that exactly the two agents above were created.
log_behavior_graph_to_wandb(models[0], id="Model 0 after Step 2")
log_behavior_graph_to_wandb(models[1], id="Model 1 after Step 2")

# Fraction of random perturbations that do not improve on each model
# against the other.
print(best_response_test(models[0],models[1], n_of_perturbations=1000))
print(best_response_test(models[1],models[0], n_of_perturbations=1000))

# If step 2 used 0 noise, we now need to set the noise to something non-zero.
if args.NO_NOISE_STEP_2:
    test_diff_noises, diff_noise_vals_pl1, diff_noise_vals_pl2 = generate_noise(), generate_noise(), generate_noise()
    # test_inps was built while the test noise was still zero. Rebuild it so
    # that the freshly drawn test noise is actually used by diff(), as it is
    # when step 2 runs with noise.
    test_diffs += test_diff_noises
    test_inps = torch.cat((test_diffs, inps), dim=1)

"""## LOLA"""

def is_successful(loss_hist, threshold_factor=0.8, backward_horizon=100):
  """Return whether player 0's loss stayed at or below the threshold over the recent history.

  Args:
    loss_hist: list of (loss_0, loss_1) tuples, oldest first.
    threshold_factor: the threshold is threshold_factor * args.G.
    backward_horizon: number of most recent entries to inspect.

  Returns:
    True if none of the last backward_horizon entries has a player 0 loss
    above the threshold; False otherwise, including when the history is
    still shorter than backward_horizon.
  """
  if len(loss_hist) < backward_horizon:
    # Not enough history yet to judge.
    return False
  threshold = threshold_factor * args.G
  # Slice from an explicit start index so that backward_horizon == 0 gives
  # an empty window rather than the whole list.
  recent = loss_hist[len(loss_hist)-backward_horizon:] if backward_horizon > 0 else []
  return not any(entry[0] > threshold for entry in recent)


import math

# Run up to args.N_OF_LOLA_STEPS LOLA updates on both models, logging the
# losses to wandb after every update. The loop ends early if the run has
# failed (both losses above G, or a loss is NaN/inf).
early_termination = False
old_loss_vals = (0,0)
loss_hist = [] #for testing convergence for deciding whether to terminate early
lola_algo = 'taylor_lola' if args.TAYLOR_LOLA else 'exact_lola'
for step_no in range(args.N_OF_LOLA_STEPS):
  # Lookahead rate, optionally decayed linearly towards zero over the run.
  la = args.LOLA_LA
  if args.LOLA_LINEAR_LA_DECAY:
    la = args.LOLA_LA * (1- step_no/args.N_OF_LOLA_STEPS)

  loss_vals = lola_update(loss_without_forward,models,beta=la,alpha=args.LOLA_LR, algo=lola_algo)
  current_losses = (loss_vals[0].item(), loss_vals[1].item())
  loss_hist.append(current_losses)
  wandb.log({"loss_0": current_losses[0],
             "loss_1": current_losses[1],
             'turn_loss_decrease_0' : old_loss_vals[0]-current_losses[0],
             'turn_loss_decrease_1' : old_loss_vals[1]-current_losses[1],
             'loss of (D,D)': args.G,
             'loss of (C,C)': 1,
             'agent_diffs': diff(models[0], models[1]).item()})
  old_loss_vals = current_losses

  if step_no % args.LOG_BEHAVIOR_GRAPH_EVERY_N_STEPS == 0:
    log_behavior_graph_to_wandb(models[0], id="Model 0 after "+str(step_no)+" steps of LOLA")
    log_behavior_graph_to_wandb(models[1], id="Model 1 after "+str(step_no)+" steps of LOLA")
  print("Turn", step_no, ". Current loss:", current_losses[0], ",", current_losses[1])

  # Failure 1: both players are worse off than G, early termination is
  # enabled, and either step 2 ran for more than one step or the shield
  # period is over.
  mutual_defection = (current_losses[0] > args.G and current_losses[1] > args.G
                      and args.LOLA_EARLY_TERMINATION
                      and (args.N_OF_STEPS_STEP_2>1 or step_no>args.LOLA_EARLY_TERMINATION_SHIELD))
  # Failure 2: a loss is no longer a finite number.
  non_finite = any(math.isnan(v) or math.isinf(v) for v in current_losses)
  if mutual_defection or non_finite:
    early_termination = True
    wandb_run.tags = wandb_run.tags + ("Fail",)
    wandb.log({
        "n_of_lola_steps_before_failure" : step_no
    })
    # No further updates are made after a failure.
    break
if not early_termination:
  wandb.log({
          "n_of_lola_steps_before_failure" : args.N_OF_LOLA_STEPS
  })

# Record both models' behavior curves after the LOLA phase.
log_behavior_graph_to_wandb(models[0], id="Model 0 after LOLA")
log_behavior_graph_to_wandb(models[1], id="Model 1 after LOLA")

# Fraction of random perturbations (default count) that do not improve on
# each model against the other.
print (best_response_test(models[0],models[1]))
print (best_response_test(models[1],models[0]))

"""## Mutual best response learning ("Step 3")"""

# Both agents take turns updating; the models are modified in place.
step3(models[0], models[1], n_of_turns = args.N_OF_TURNS_STEP_3,\
      print_progress=True, self_play=False, n_of_steps_per_turn = args.N_OF_STEPS_PER_TURN_STEP_3,\
      early_termination_at_mutual_D=True)

# Record the final behavior curves and best-response estimates.
log_behavior_graph_to_wandb(models[0], id="Model 0 after Step 3")
log_behavior_graph_to_wandb(models[1], id="Model 1 after Step 3")

print (best_response_test(models[0],models[1]))
print (best_response_test(models[1],models[0]))

# Close the wandb run.
wandb.finish()

